import java.util.HashMap;
import java.util.Map;
import java.util.Objects;

public class Solution1074 {
    /**
     * Memo key used by {@link #getInterval}: a column band {@code [startCol, endCol]} whose
     * running sum started at {@code startRow}.
     *
     * <p>Kept as a non-static inner class so any existing {@code outer.new SubMatrix(...)}
     * call sites continue to compile.
     */
    class SubMatrix{
        int startRow;
        int startCol;
        int endCol;

        public SubMatrix(int startRow, int startCol, int endCol) {
            this.startRow = startRow;
            this.startCol = startCol;
            this.endCol = endCol;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) return true;
            if (!(o instanceof SubMatrix)) return false;
            SubMatrix subMatrix = (SubMatrix) o;
            return startRow == subMatrix.startRow &&
                    startCol == subMatrix.startCol &&
                    endCol == subMatrix.endCol;
        }

        @Override
        public int hashCode() {
            return Objects.hash(startRow, startCol, endCol);
        }
    }

    /**
     * Row-wise prefix sums: {@code prefix[r][c] == matrix[r][0] + ... + matrix[r][c]}.
     * Rebuilt by every call to {@link #numSubmatrixSumTarget}.
     */
    public int[][] prefix;

    /**
     * Running band sums accumulated by {@link #getInterval}, keyed by
     * (startRow, startCol, endCol). Cleared by {@link #numSubmatrixSumTarget} because
     * entries computed against a previous {@link #prefix} table are meaningless for a new one.
     */
    public Map<SubMatrix,Integer> mark=new HashMap<>();

    /**
     * Adds row {@code jX}'s sum over columns {@code iY..jY} to the running total memoized under
     * key (iX, iY, jY), stores the new total, and returns it.
     *
     * <p>{@code jX} is not part of the key. The returned value equals the sum of the submatrix
     * with corners (iX, iY) and (jX, jY) only if, since the key was last absent from
     * {@link #mark}, this method was called for rows iX, iX+1, ..., jX exactly once each and in
     * that order.
     *
     * @param iX top row of the band (part of the memo key)
     * @param iY left column, inclusive
     * @param jX row whose segment is added to the running total
     * @param jY right column, inclusive
     * @return the updated running total for key (iX, iY, jY)
     */
    public int getInterval(int iX,int iY,int jX,int jY){
        SubMatrix key=new SubMatrix(iX,iY,jY);
        Integer previous=mark.get(key);
        int segment=rowSegmentSum(jX,iY,jY);
        int total=previous==null? segment : previous+segment;
        mark.put(key,total);
        return total;
    }

    /**
     * Counts the non-empty submatrices of {@code matrix} whose elements sum to {@code target}.
     *
     * <p>For every pair of columns (left, right), each row is collapsed to its sum over that
     * column band. The problem then reduces to counting contiguous row ranges with sum
     * {@code target}, which a hash map of previously seen prefix sums answers in one pass.
     * Total cost is O(n² · m) time and O(m) extra space, for m rows and n columns.
     *
     * <p>Side effects: rebuilds {@link #prefix} for {@code matrix} and clears {@link #mark}, so
     * repeated calls on the same instance are independent of each other.
     *
     * @param matrix rectangular matrix of integers; an empty matrix yields 0
     * @param target the required submatrix sum
     * @return the number of submatrices summing to {@code target}
     */
    public int numSubmatrixSumTarget(int[][] matrix, int target) {
        int m=matrix.length;
        // Guard: matrix[0] does not exist for an empty matrix.
        int n=m==0? 0 : matrix[0].length;

        // Drop state left by earlier calls; old memo entries refer to a different matrix.
        mark.clear();
        prefix=new int[m][n];
        for (int r = 0; r < m; r++) {
            int running=0;
            for (int c = 0; c < n; c++) {
                running+=matrix[r][c];
                prefix[r][c]=running;
            }
        }

        int count=0;
        // Reused across column pairs; maps a prefix sum over rows to how often it occurred.
        Map<Integer,Integer> seen=new HashMap<>();
        for (int left = 0; left < n; left++) {
            for (int right = left; right < n; right++) {
                seen.clear();
                // The empty prefix (sum 0) lets row ranges that start at row 0 be counted.
                seen.put(0,1);
                int stripSum=0;
                for (int r = 0; r < m; r++) {
                    stripSum+=rowSegmentSum(r,left,right);
                    // Each earlier prefix equal to stripSum - target closes a range summing to target.
                    count+=seen.getOrDefault(stripSum-target,0);
                    seen.merge(stripSum,1,Integer::sum);
                }
            }
        }
        return count;
    }

    /** Returns the sum of row {@code row} over columns {@code left..right}, read from {@link #prefix}. */
    private int rowSegmentSum(int row,int left,int right){
        int[] rowPrefix=prefix[row];
        return rowPrefix[right]-(left>0? rowPrefix[left-1] : 0);
    }

    public static void main(String[] args) {
        Solution1074 s=new Solution1074();
        System.out.println(s.numSubmatrixSumTarget(
                new int[][] {{0,1,0},{1,1,1},{0,1,0}},
        0));
    }
}
